import pandas as p
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from AnalyzeHistoricalNHTS import getVehicleDataByYear, getPersonDataByYear, getHouseholdDataByYear, get_state_safety_inspection_df_by_year, get_state_safety_inspection_dict
from tqdm import tqdm

def build_age_lookup(per_df):
    """Map each ``(HOUSEID, PERSONID)`` pair in *per_df* to that person's ``R_AGE``.

    If a key occurs more than once, the last row wins (plain dict assignment order).
    """
    keys = zip(per_df['HOUSEID'], per_df['PERSONID'])
    return dict(zip(keys, per_df['R_AGE']))


def lookup_driver_ages(veh_df, age_lookup, missing=-1):
    """Return a list of ages aligned row-for-row with *veh_df*.

    Each vehicle row is keyed by its ``(HOUSEID, PERSONID)`` pair; rows whose key
    is absent from *age_lookup* get *missing* (default ``-1``), which the caller's
    ``Driver Age >= 16`` filter later drops.
    """
    keys = zip(veh_df['HOUSEID'], veh_df['PERSONID'])
    return [age_lookup.get(key, missing) for key in keys]


def flag_new_vehicles(veh_df, first_new_model_year):
    """Return an int64 0/1 Series: 1 where ``VEHYEAR >= first_new_model_year``."""
    return (veh_df['VEHYEAR'] >= first_new_model_year).astype('int64')


def plot_by_inspection(df, y, ylabel, title, no_safety_label_xy, safety_label_xy,
                       outfile, figsize):
    """Plot *y* against 'Driver Age', one line per safety-inspection regime.

    Lines are split by the 'Subject to Safety Inspection' column and aggregated
    with seaborn's mean estimator weighted by 'WTHHFIN'. Text labels placed at
    *no_safety_label_xy* / *safety_label_xy* replace a legend. The figure is saved
    to *outfile* (its parent directory is created if missing) and then shown.

    Returns the ``(fig, ax)`` pair.
    """
    import os

    no_safety_color = 'xkcd:red'
    safety_color = 'xkcd:blue'

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    # seaborn assigns palette colours to hue levels in sorted order, so the first
    # colour presumably goes to the "not subject" level (0/False) — confirm the
    # encoding of HasSafety.
    sns.lineplot(data=df, x='Driver Age', y=y, hue='Subject to Safety Inspection',
                 weights='WTHHFIN', palette=[no_safety_color, safety_color],
                 legend=False, ax=ax)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Main Driver Age')
    ax.text(*no_safety_label_xy, 'Drivers not subject\nto safety inspections',
            color=no_safety_color, ha='center')
    ax.text(*safety_label_xy, 'Drivers subject\nto safety inspections',
            color=safety_color, ha='center')
    ax.set_title(title)

    # savefig does not create directories; avoid FileNotFoundError on a fresh checkout.
    out_dir = os.path.dirname(outfile)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(outfile, bbox_inches='tight')
    plt.show()
    return fig, ax


if __name__ == '__main__':
    year = 2017
    vehData = getVehicleDataByYear(year)
    perData = getPersonDataByYear(year)
    safety_inspection_df = get_state_safety_inspection_df_by_year(year)

    # Attach the age of the person referenced by each vehicle record.
    # NOTE(review): plots label this "Main Driver Age"; this assumes the vehicle
    # record's PERSONID is the main driver (cf. WHOMAIN) — TODO confirm.
    age_lookup = build_age_lookup(perData)
    vehData.loc[:, 'Driver Age'] = lookup_driver_ages(vehData, age_lookup)

    # A vehicle counts as "new" if its model year is at most one year before the
    # survey year (2016+ for the 2017 survey).
    vehData.loc[:, 'New'] = flag_new_vehicles(vehData, year - 1)

    # Keep matched drivers aged 16+ and non-negative vehicle ages (negative VEHAGE
    # values are presumably NHTS missing-value codes — confirm).
    vehData = vehData.loc[(vehData['Driver Age'] >= 16) & (vehData['VEHAGE'] >= 0), :]
    # Default inner join: vehicles in states absent from safety_inspection_df are dropped.
    vehData = vehData.merge(safety_inspection_df, left_on='HHSTATE', right_on='State Code')
    vehData.loc[:, 'Subject to Safety Inspection'] = vehData.loc[:, 'HasSafety']

    # Figure is 5 in. wide with a golden-ratio aspect.
    fig_width = 5
    golden_ratio = (1 + np.sqrt(5)) / 2
    fig_size = (fig_width, fig_width / golden_ratio)

    plot_by_inspection(
        vehData,
        y='New',
        ylabel='Share Owning New Vehicle',
        title='Share Owns New Vehicle vs. Main Driver Age\nBy Safety Inspection Requirement',
        no_safety_label_xy=(40, .15),
        safety_label_xy=(50, .00),
        outfile='Plots/ShareNewVehicleVsDriverAge.png',
        figsize=fig_size,
    )
    # The vehicle-age variant of this figure is:
    #   plot_by_inspection(vehData, y='VEHAGE', ylabel='Vehicle Age',
    #       title='Vehicle Age vs. Main Driver Age\nBy Safety Inspection Requirement',
    #       no_safety_label_xy=(50, 12.75), safety_label_xy=(70, 7.75),
    #       outfile='Plots/VehicleAgeVsDriverAge.png', figsize=fig_size)

